import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class LeNet(nn.Module):
    """LeNet-5 style classifier for the MNIST dataset.

    Three unpadded convolutions with tanh activations (2x2 average pooling
    after the first two) feed two fully connected layers. ``fc1`` expects
    120 flattened features, which is what a single-channel 28x28 input
    is reduced to by the convolution/pooling stack.
    """

    def __init__(self, num_classes):
        """Create the layers.

        Args:
            num_classes: number of logits produced by the final layer.
        """
        super().__init__()
        # Convolutional trunk: stride 1, no padding.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6,
                               kernel_size=5, stride=1, padding=0)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16,
                               kernel_size=5, stride=1, padding=0)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=120,
                               kernel_size=4, stride=1, padding=0)
        # Classifier head; the 82-wide hidden layer is the embedding.
        self.fc1 = nn.Linear(in_features=120, out_features=82)
        self.fc2 = nn.Linear(in_features=82, out_features=num_classes)
        # Shared parameter-free modules.
        self.tanh = nn.Tanh()
        self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Run the network on a batch.

        Args:
            x: input batch; the first dimension is treated as the batch.

        Returns:
            A tuple ``(predict, feature, x_final)``:
            ``predict`` is the ``fc2`` output (logits), ``feature`` is the
            flattened output of the last convolution (after tanh), and
            ``x_final`` is the tanh-activated ``fc1`` output.
        """
        activate = self.tanh
        pool = self.avgpool

        stage1 = pool(activate(self.conv1(x)))
        stage2 = pool(activate(self.conv2(stage1)))
        conv_out = activate(self.conv3(stage2))

        # Collapse everything except the batch dimension.
        feature = conv_out.reshape(conv_out.size(0), -1)
        x_final = activate(self.fc1(feature))
        return self.fc2(x_final), feature, x_final

    def get_embedding_dim(self):
        """Return the width of the last hidden layer (the ``fc1`` output).

        This is the size of ``x_final`` returned by ``forward`` and is used
        for gradient embeddings.
        """
        return 82

class VGG16(nn.Module):
    """VGG-16 style classifier for the CIFAR-10 dataset.

    Five convolutional stages (2, 2, 3, 3 and 3 convolutions, all 3x3 with
    stride 1 and padding 1). Only the last convolution of each stage is
    followed by batch normalisation, and every stage ends with 2x2 max
    pooling. ``fc1`` expects 512 flattened features, which is what a
    3-channel 32x32 input is reduced to after the five poolings.
    """

    def __init__(self, num_classes):
        """Create the layers.

        Args:
            num_classes: number of logits produced by the final layer.
        """
        super().__init__()

        def conv3x3(c_in, c_out):
            # Every convolution here shares kernel 3, stride 1, padding 1.
            return nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1)

        # Stage 1: 3 -> 64 channels.
        self.conv1_1 = conv3x3(3, 64)
        self.conv1_2 = conv3x3(64, 64)
        self.batchnorm1 = nn.BatchNorm2d(num_features=64)

        # Stage 2: 64 -> 128 channels.
        self.conv2_1 = conv3x3(64, 128)
        self.conv2_2 = conv3x3(128, 128)
        self.batchnorm2 = nn.BatchNorm2d(num_features=128)

        # Stage 3: 128 -> 256 channels.
        self.conv3_1 = conv3x3(128, 256)
        self.conv3_2 = conv3x3(256, 256)
        self.conv3_3 = conv3x3(256, 256)
        self.batchnorm3 = nn.BatchNorm2d(num_features=256)

        # Stage 4: 256 -> 512 channels.
        self.conv4_1 = conv3x3(256, 512)
        self.conv4_2 = conv3x3(512, 512)
        self.conv4_3 = conv3x3(512, 512)
        self.batchnorm4 = nn.BatchNorm2d(num_features=512)

        # Stage 5: 512 -> 512 channels.
        self.conv5_1 = conv3x3(512, 512)
        self.conv5_2 = conv3x3(512, 512)
        self.conv5_3 = conv3x3(512, 512)
        self.batchnorm5 = nn.BatchNorm2d(num_features=512)

        # Classifier head; the 256-wide hidden layer is the embedding.
        self.fc1 = nn.Linear(in_features=512, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=num_classes)
        self.pooling = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Run the network on a batch.

        Args:
            x: input batch; the first dimension is treated as the batch.

        Returns:
            A tuple ``(predict, feature, feature)``: ``predict`` is the
            ``fc2`` output (logits) and ``feature`` is the ReLU-activated
            ``fc1`` output. The same tensor is returned in both the second
            and third positions.
        """
        # Each entry: (convs followed by ReLU only, final conv, its batch norm).
        stages = (
            ((self.conv1_1,), self.conv1_2, self.batchnorm1),
            ((self.conv2_1,), self.conv2_2, self.batchnorm2),
            ((self.conv3_1, self.conv3_2), self.conv3_3, self.batchnorm3),
            ((self.conv4_1, self.conv4_2), self.conv4_3, self.batchnorm4),
            ((self.conv5_1, self.conv5_2), self.conv5_3, self.batchnorm5),
        )

        out = x
        for leading_convs, closing_conv, norm in stages:
            for conv in leading_convs:
                out = F.relu(conv(out))
            out = F.relu(norm(closing_conv(out)))
            out = self.pooling(out)

        # Collapse everything except the batch dimension.
        flat = out.view(out.size(0), -1)
        feature = F.relu(self.fc1(flat))
        predict = self.fc2(feature)
        return predict, feature, feature

    def get_embedding_dim(self):
        """Return the width of the last hidden layer (the ``fc1`` output).

        This is the size of ``feature`` returned by ``forward`` and is used
        for gradient embeddings.
        """
        return 256
